{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    "### https://towardsdatascience.com/logistic-regression-using-python-sklearn-numpy-mnist-handwriting-recognition-matplotlib-a6b31e2b166a"
   ],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "source": [
    "import sys\n",
    "from pathlib import Path\n",
    "curr_path = str(Path().absolute())\n",
    "parent_path = str(Path().absolute().parent)\n",
    "sys.path.append(parent_path) # add current terminal path to sys.path\n",
    "\n",
    "from Mnist.load_data import load_local_mnist\n",
    "\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "from sklearn.metrics import classification_report"
   ],
   "outputs": [],
   "metadata": {}
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "source": [
    "(X_train, y_train), (X_test, y_test) = load_local_mnist(normalize = False,one_hot = False)\n",
    "\n",
    "X_train, y_train= X_train[:2000], y_train[:2000] \n",
    "X_test, y_test = X_test[:200],y_test[:200]\n",
    "\n",
    "# solver：即使用的优化器，lbfgs：拟牛顿法， sag：随机梯度下降\n",
    "model = LogisticRegression(solver='lbfgs', max_iter=500) # lbfgs：拟牛顿法\n",
    "model.fit(X_train, y_train)\n",
    "y_pred = model.predict(X_test)\n",
    "print(classification_report(y_test, y_pred)) # 打印报告"
   ],
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       0.94      1.00      0.97        17\n",
      "           1       1.00      0.93      0.96        28\n",
      "           2       0.76      0.81      0.79        16\n",
      "           3       0.88      0.94      0.91        16\n",
      "           4       0.93      0.89      0.91        28\n",
      "           5       0.94      0.85      0.89        20\n",
      "           6       0.90      0.90      0.90        20\n",
      "           7       1.00      0.88      0.93        24\n",
      "           8       0.91      1.00      0.95        10\n",
      "           9       0.84      1.00      0.91        21\n",
      "\n",
      "    accuracy                           0.92       200\n",
      "   macro avg       0.91      0.92      0.91       200\n",
      "weighted avg       0.92      0.92      0.92       200\n",
      "\n"
     ]
    }
   ],
   "metadata": {}
  }
 ],
 "metadata": {
  "orig_nbformat": 4,
  "language_info": {
   "name": "python",
   "version": "3.7.11",
   "mimetype": "text/x-python",
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "pygments_lexer": "ipython3",
   "nbconvert_exporter": "python",
   "file_extension": ".py"
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3.7.11 64-bit ('py37': conda)"
  },
  "interpreter": {
   "hash": "42c9bd6a7bb2602a627a03adc5d7029ff8a60f7e68b58a1beb81d0b4574fea3d"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}